from __future__ import division, print_function

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import mi_estimator
import warnings
# To prevent PIL warnings.
warnings.filterwarnings("ignore")
from torchmetrics import Accuracy
import data
from torch.autograd import Variable
from tqdm import tqdm
import torchnet as tnt
import utils
# import pickle
# import tensorflow as tf
# from art.estimators.classification import PyTorchClassifier
# from torchsummary import summary
# from autoattack.autoattack import AutoAttack
# import pytorch_lightning
# import time
# from art.estimators.classification import TensorFlowV2Classifier
# import art.attacks.evasion

######################################################################################################################################################################
###
###     Manager class: training, evaluation, checkpointing and partial weight loading
###
######################################################################################################################################################################


class Manager(object):
    """Handles training, evaluation, checkpointing and partial weight (re)loading.

    The model is treated as two consecutive parts split at an index into
    ``model.named_modules()``: modules whose enumeration index is below
    ``f_b_start`` form f_a, the remaining modules form f_b.
    """

    # Number of non-improving epochs tolerated before the best checkpoint is
    # reloaded (the reload itself happens on the following non-improving epoch).
    _PATIENCE = 3

    # Layer types that carry restorable state, with the tensors to restore for each.
    # Checked in order; the first matching type wins.
    _RESTORABLE_FIELDS = (
        (nn.BatchNorm2d, ("weight", "bias", "running_mean", "running_var", "num_batches_tracked")),
        (nn.Conv2d, ("weight", "bias")),
        (nn.Linear, ("weight", "bias")),
    )

    def __init__(self, args, model, trainloader, valloader, testloader, advtrainloader, advvalloader, advtestloader):
        """Stores the model, the clean/adversarial dataloaders and the loss.

        Args:
            args: Run configuration. This class reads ``cuda``, ``dataset``,
                ``network``, ``attacktype``, ``eps``, ``num_fb_layers`` and
                ``epochs`` from it.
            model: The ``nn.Module`` to train and evaluate.
            trainloader, valloader, testloader: Dataloaders for clean data.
            advtrainloader, advvalloader, advtestloader: Dataloaders for
                adversarial data.
        """
        self.args = args
        self.cuda = args.cuda
        self.model = model

        self.train_data_loader = trainloader
        self.val_data_loader = valloader
        self.test_data_loader = testloader
        self.adv_train_data_loader = advtrainloader
        self.adv_val_data_loader = advvalloader
        self.adv_test_data_loader = advtestloader
        self.criterion = nn.CrossEntropyLoss()

    ##################################################################
    ###     Evaluation
    ##################################################################

    def _select_loader(self, adversarial, data):
        """Returns the dataloader for the requested split.

        Raises:
            ValueError: If ``data`` is not "Test", "Validation" or "Training".
        """
        split_to_attr = {
            "Test": "test_data_loader",
            "Validation": "val_data_loader",
            "Training": "train_data_loader",
        }
        if data not in split_to_attr:
            # Previously this only printed and then failed later with an
            # UnboundLocalError; fail early with a meaningful error instead.
            raise ValueError("Invalid data selection for eval(): %r" % (data,))
        prefix = "adv_" if adversarial else ""
        return getattr(self, prefix + split_to_attr[data])

    def eval(self, biases=None, adversarial=False, data="Test"):
        """Performs evaluation on the training, validation, or testing dataloader.

        Args:
            biases: Unused; kept for interface compatibility.
            adversarial: If truthy, evaluate on the adversarial dataloader of
                the chosen split, otherwise on the clean one.
            data: One of "Test", "Validation" or "Training".

        Returns:
            The value of the torchnet ``ClassErrorMeter``: top-1 error first,
            followed by top-5 error when the model has more than 5 outputs.

        Raises:
            ValueError: If ``data`` is not a known split.
            RuntimeError: If the selected dataloader yields no batches.

        Note:
            The model is always left in train mode afterwards, regardless of
            the mode it was in before the call.
        """
        dataloader = self._select_loader(adversarial, data)

        self.model.eval()
        error_meter = None
        topk = [1]

        print('Performing eval...')

        # `Variable(..., volatile=True)` no longer disables autograd in
        # PyTorch >= 0.4; no_grad() is the supported way to skip graph building.
        with torch.no_grad():
            for batch, label in tqdm(dataloader, desc='Eval'):
                if self.cuda:
                    batch = batch.cuda()

                output = self.model(batch)

                # Init error meter lazily: the number of classes is only known
                # once the first output has been produced.
                if error_meter is None:
                    if output.size(1) > 5:
                        topk.append(5)
                    error_meter = tnt.meter.ClassErrorMeter(topk=topk)
                error_meter.add(output.data, label)

        self.model.train()

        if error_meter is None:
            raise RuntimeError("eval() received an empty dataloader for data=%r" % (data,))

        errors = error_meter.value()
        print('Error: ' + ', '.join('@%s=%.2f' %
                                    t for t in zip(topk, errors)))
        return errors

    ##################################################################
    ###     Training
    ##################################################################

    def train_epoch(self, epoch_idx, optimizer, adversarial=False):
        """Trains model for one epoch.

        Args:
            epoch_idx: Epoch number, used only for the progress-bar label.
            optimizer: Optimizer whose ``step()`` is called after every batch.
            adversarial: If truthy, iterate over the adversarial training
                dataloader, otherwise over the clean one.
        """
        ### As a reminder, we generate the perturbation once on the whole network rather than batchwise at each epoch
        ### This setting still observably detriments the non-robust model's accuracy and demonstrates the relevant behavior
        ###    of a semirobust network
        loader = self.adv_train_data_loader if adversarial else self.train_data_loader

        for batch, label in tqdm(loader, desc='Epoch: %d ' % (epoch_idx), disable=True):
            if self.cuda:
                batch = batch.cuda()
                label = label.cuda()

            # Set grads to 0.
            self.model.zero_grad()

            # Do forward-backward.
            output = self.model(batch)
            self.criterion(output, label).backward()

            # Update params.
            optimizer.step()

    def _checkpoint_path(self):
        """Creates the save directory for this run and returns the checkpoint file path."""
        base_path = "/".join([".", "saves", self.args.dataset, self.args.network,
                              self.args.attacktype, str(self.args.eps)])
        os.makedirs(base_path, exist_ok=True)
        return base_path + "/checkpoint_" + str(self.args.num_fb_layers) + "_" + str(self.args.epochs)

    def _checkpoint_step(self, checkpoint_path, val_accuracy, best_val_accuracy, patience, message):
        """Saves on improvement, otherwise spends patience or reloads the checkpoint.

        Patience is only reset when the checkpoint is reloaded, not when the
        validation accuracy improves.

        Returns:
            The updated ``(best_val_accuracy, patience)`` pair.
        """
        if val_accuracy >= best_val_accuracy:
            self.save_model(checkpoint_path)
            print(message % (best_val_accuracy, val_accuracy))
            return val_accuracy, patience
        if patience <= 0:
            self.load_model(torch.load(checkpoint_path))
            return best_val_accuracy, self._PATIENCE
        return best_val_accuracy, patience - 1

    def train(self, epochs, optimizer, save=True, target_accuracy=0, best_val_accuracy=0, adversarial=False):
        """Performs training with validation-based checkpointing.

        Args:
            epochs: Number of epochs to run.
            optimizer: Optimizer passed through to ``train_epoch``.
            save: Unused; kept for interface compatibility.
            target_accuracy: Unused; kept for interface compatibility.
            best_val_accuracy: Validation accuracy (percent) a new checkpoint
                has to reach or exceed in order to be saved.
            adversarial: Train and validate on adversarial data if truthy.

        Returns:
            The best top-1 validation accuracy (percent) observed.
        """
        patience = self._PATIENCE
        checkpoint_path = self._checkpoint_path()

        if self.args.cuda:
            self.model = self.model.cuda()

        for epoch_idx in range(1, epochs + 1):
            print('Epoch: %d' % (epoch_idx))

            self.model.train()
            self.train_epoch(epoch_idx, optimizer, adversarial=adversarial)

            ### Check validation accuracy for checkpointing
            val_errors = self.eval(adversarial=adversarial, data="Validation")
            val_accuracy = 100 - val_errors[0]  # Top-1 accuracy.

            best_val_accuracy, patience = self._checkpoint_step(
                checkpoint_path, val_accuracy, best_val_accuracy, patience,
                'Best model so far, Accuracy: %0.2f%% -> %0.2f%%')

        print('Finished finetuning...')
        print('Best val error/accuracy: %0.2f%%, %0.2f%%' %
              (100 - best_val_accuracy, best_val_accuracy))
        print('-' * 16)
        return best_val_accuracy

    def train_subnetwork(self, epochs, optimizer, save=True, target_accuracy=0, adversarial=False, delta=0):
        """Performs training, stopping early once within delta of the target accuracy ACC*.

        Which parameters are actually updated is determined by the optimizer
        and by ``requires_grad`` (see ``freeze_fa``); this method itself does
        not restrict training to f_b.

        Args:
            epochs: Maximum number of epochs to run.
            optimizer: Optimizer passed through to ``train_epoch``.
            save: Unused; kept for interface compatibility.
            target_accuracy: Target top-1 test accuracy (percent).
            adversarial: Train and evaluate on adversarial data if truthy.
            delta: Training stops after the first epoch for which
                ``target_accuracy - test_accuracy < delta``.

        Returns:
            Tuple ``(best_test_accuracy, best_val_accuracy,
            best_train_accuracy, epoch_idx)`` where ``epoch_idx`` is the
            1-based index of the last epoch run (0 if ``epochs`` is 0).
        """
        best_val_accuracy = 0
        best_test_accuracy = 0
        best_train_accuracy = 0

        patience = self._PATIENCE
        checkpoint_path = self._checkpoint_path()

        if self.args.cuda:
            self.model = self.model.cuda()

        # Defined up front so the final return is valid even when epochs == 0.
        epoch_idx = 0
        for epoch_idx in range(1, epochs + 1):
            print('Epoch: %d' % (epoch_idx))

            self.model.train()
            self.train_epoch(epoch_idx, optimizer, adversarial=adversarial)

            # Top-1 accuracies on all three splits.
            test_accuracy = 100 - self.eval(adversarial=adversarial, data="Test")[0]
            val_accuracy = 100 - self.eval(adversarial=adversarial, data="Validation")[0]
            train_accuracy = 100 - self.eval(adversarial=adversarial, data="Training")[0]

            ### Checkpointing on validation set
            best_val_accuracy, patience = self._checkpoint_step(
                checkpoint_path, val_accuracy, best_val_accuracy, patience,
                'Best subnetwork validation accuracy so far, Accuracy: %0.2f%% -> %0.2f%%')

            best_train_accuracy = max(best_train_accuracy, train_accuracy)
            best_test_accuracy = max(best_test_accuracy, test_accuracy)

            # Stop as soon as the current test accuracy is within delta of the target.
            if (target_accuracy - test_accuracy) < delta:
                print('Within delta, Target_accuracy: %0.2f%%, Best subnetwork test accuracy: %0.2f%%' %
                      (target_accuracy, test_accuracy))
                return best_test_accuracy, best_val_accuracy, best_train_accuracy, epoch_idx

        print('Finished finetuning...')
        print('Best test error/accuracy: %0.2f%%, %0.2f%%' %
              (100 - best_test_accuracy, best_test_accuracy))
        print('-' * 16)
        return best_test_accuracy, best_val_accuracy, best_train_accuracy, epoch_idx

    ##################################################################
    ###     Mutual information
    ##################################################################

    def _flat_activations(self, samples, layer_index):
        """Returns the activations of one layer, flattened to (num_samples, -1).

        Assumes ``utils.activations`` returns an array-like with a ``shape``
        whose first axis indexes samples -- TODO confirm against utils.
        """
        acts = utils.activations(samples, self.model, self.args.cuda, layer_index)
        if len(acts.shape) > 2:
            acts = np.reshape(acts, (np.shape(acts)[0], -1))
        return acts

    ### Get MI between a pair of layers in f_b
    def get_mi_estimate(self, parent_index, child_index, x_split, label_counts):
        """Estimates the class-weighted mutual information between two layers.

        Args:
            parent_index: Layer index forwarded to ``utils.activations`` for
                the parent layer.
            child_index: Layer index forwarded to ``utils.activations`` for
                the child layer.
            x_split: Indexable by class id; ``x_split[i]`` is the input handed
                to ``utils.activations`` for class ``i`` (presumably the
                samples of that class -- verify against the caller).
            label_counts: Indexable by class id; weight applied to the MI
                estimate of class ``i`` (presumably the fractional occurrence
                of that class -- verify against the caller).

        Returns:
            Sum over classes of ``label_counts[i] * EDGE(parent, child)``.
            The number of classes is 100 for CIFAR100 and 10 otherwise.
        """
        num_classes = 100 if self.args.dataset == "CIFAR100" else 10

        mi_est = 0
        for cls in range(num_classes):
            act_parent = self._flat_activations(x_split[cls], parent_index)
            act_child = self._flat_activations(x_split[cls], child_index)

            ### multiply by the fractional occurrence of the current class
            mi_est += label_counts[cls] * mi_estimator.EDGE(
                act_parent, act_child,
                normalize_epsilon=False, L_ensemble=1, stochastic=True)

        return mi_est

    ##################################################################
    ###     Saving / loading
    ##################################################################

    def save_model(self, path):
        """Saves the model's state dict to ``path``."""
        print("Saving model to path: ", path)
        torch.save(self.model.state_dict(), path)

    def _restore_layers(self, state_dict, f_b_start, restore_fb, verbose):
        """Copies BatchNorm2d/Conv2d/Linear state from ``state_dict`` into part of the model.

        Args:
            state_dict: Mapping from ``"<module name>.<tensor name>"`` to tensors.
            f_b_start: Index into ``model.named_modules()`` where f_b begins.
            restore_fb: If True restore modules with index >= ``f_b_start``,
                otherwise those with index < ``f_b_start``.
            verbose: If True, print every restored module.
        """
        with torch.no_grad():
            for idx, (name, module) in enumerate(self.model.named_modules()):
                if (idx >= f_b_start) != restore_fb:
                    continue
                for layer_type, fields in self._RESTORABLE_FIELDS:
                    if not isinstance(module, layer_type):
                        continue
                    if verbose:
                        if layer_type is nn.BatchNorm2d:
                            print(idx, " ", name)
                        else:
                            print(idx, " ", name, " ", module)
                    for field in fields:
                        # Skip tensors the layer does not have (e.g. bias=False).
                        # This replaces the former network-name check, which
                        # left biases stale on networks other than VGG16/AlexNet
                        # and crashed on bias-free layers.
                        target = getattr(module, field, None)
                        if target is not None:
                            target.copy_(state_dict[name + "." + field])
                    break

    ### Reload all weights in layers of f_b only
    def load_fb(self, state_dict, f_b_start=0):
        """Restores BatchNorm2d/Conv2d/Linear state for modules at index >= ``f_b_start``.

        Prints ``f_b_start`` and every module that is restored.
        """
        print("f_b_start: ", f_b_start)
        self._restore_layers(state_dict, f_b_start, restore_fb=True, verbose=True)

    ### Reload all weights in layers of f_a only. This is only for debugging purposes at the moment
    def load_fa(self, state_dict, f_b_start=0):
        """Restores BatchNorm2d/Conv2d/Linear state for modules at index < ``f_b_start``."""
        print("f_b_start: ", f_b_start)
        self._restore_layers(state_dict, f_b_start, restore_fb=False, verbose=False)

    def load_model(self, state_dict):
        """Loads a complete state dict into the model."""
        self.model.load_state_dict(state_dict)

    ### Prevent training of f_a
    def freeze_fa(self, f_b_start):
        """Disables gradients for Conv2d/Linear/BatchNorm2d parameters in f_a.

        Modules with index < ``f_b_start`` in ``model.named_modules()`` are
        frozen; the index and name of every remaining module are printed.
        Only ``requires_grad`` is changed here.
        """
        print("unfrozen layers:")
        for idx, (name, module) in enumerate(self.model.named_modules()):
            if idx >= f_b_start:
                print(idx, " ", name)
            elif isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
                for param in module.parameters():
                    param.requires_grad = False
